import torch.nn as nn
import torch.nn.functional as F
import torch
from typing import Optional


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal mask.

    A single linear layer produces queries, keys and values; attention is
    computed per head with ``F.scaled_dot_product_attention(is_causal=True)``
    and the merged heads go through an output projection followed by dropout.
    """

    def __init__(self, n_embed: int, n_head: int, p_drop: float) -> None:
        """
        Args:
            n_embed: Embedding width of the input; must be divisible by ``n_head``.
            n_head: Number of attention heads.
            p_drop: Dropout probability applied after the output projection.
        """
        super().__init__()
        assert n_embed % n_head == 0
        self.p_drop = p_drop
        self.n_embed = n_embed
        self.n_head = n_head
        # One fused projection yields q, k and v side by side on the last dim.
        self.fc_attn = nn.Linear(n_embed, 3 * n_embed)
        self.fc_out = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(p_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal self-attention.

        Args:
            x: Tensor of shape (batch, seq_len, n_embed).

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, seq_len, width = x.shape
        head_dim = width // self.n_head

        # (batch, seq_len, width) -> (batch, n_head, seq_len, head_dim) for each of q, k, v.
        query, key, value = (
            part.reshape(batch, seq_len, self.n_head, head_dim).transpose(1, 2)
            for part in self.fc_attn(x).chunk(3, dim=-1)
        )

        attended = F.scaled_dot_product_attention(query, key, value, is_causal=True)

        # Move heads back next to head_dim and flatten them into the embedding axis.
        merged = attended.transpose(1, 2).reshape(batch, seq_len, width)
        return self.dropout(self.fc_out(merged))
    

class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, GELU, project back, dropout."""

    def __init__(self, n_embed: int, p_drop: float) -> None:
        """
        Args:
            n_embed: Width of the input and output features.
            p_drop: Dropout probability applied to the output.
        """
        super().__init__()
        self.n_embed = n_embed
        self.p_drop = p_drop
        hidden_width = 4 * n_embed
        self.fc_in = nn.Linear(n_embed, hidden_width)
        self.fc_out = nn.Linear(hidden_width, n_embed)
        self.act_fn = nn.GELU()
        self.dropout = nn.Dropout(p_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the feed-forward transform of ``x`` (last dim must be ``n_embed``)."""
        hidden = self.act_fn(self.fc_in(x))
        projected = self.fc_out(hidden)
        return self.dropout(projected)
    
    
class Block(nn.Module):
    """Pre-LayerNorm transformer layer: attention and MLP, each in a residual branch."""

    def __init__(self, n_embed: int, n_head: int, p_drop: float) -> None:
        """
        Args:
            n_embed: Feature width of the residual stream.
            n_head: Number of attention heads passed to ``CausalSelfAttention``.
            p_drop: Dropout probability passed to both sub-layers.
        """
        super().__init__()
        self.n_embed, self.n_head, self.p_drop = n_embed, n_head, p_drop
        self.ln_attn = nn.LayerNorm(n_embed)
        self.ln_mlp = nn.LayerNorm(n_embed)
        self.attn = CausalSelfAttention(n_embed, n_head, p_drop)
        self.mlp = MLP(n_embed, p_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both residual sub-layers and return a tensor shaped like ``x``."""
        # Each sub-layer sees a normalised copy; the un-normalised input is the skip path.
        after_attn = x + self.attn(self.ln_attn(x))
        return after_attn + self.mlp(self.ln_mlp(after_attn))
    
    
    
class GPT(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings and learned position embeddings are summed, passed
    through dropout, a stack of ``Block`` layers and a final LayerNorm, and
    then projected to vocabulary logits by ``lm_head``.
    """

    def __init__(self, n_embed: int,
                 n_head: int,
                 p_drop: float,
                 n_block: int,
                 block_size: int,
                 vocab_size: int) -> None:
        """
        Args:
            n_embed (int): Token embedding dimension.
            n_head (int): Number of attention heads.
            p_drop (float): Dropout probability.
            n_block (int): Number of stacked attention blocks.
            block_size (int): Maximum input sequence length (context window).
            vocab_size (int): Vocabulary size.
        """
        super().__init__()
        self.block_size = block_size
        self.transformer = nn.ModuleDict(dict(
            te=nn.Embedding(vocab_size, n_embed),   # token embedding
            pe=nn.Embedding(block_size, n_embed),   # learned position embedding
            blocks=nn.ModuleList([Block(n_embed, n_head, p_drop) for _ in range(n_block)]),
            ln=nn.LayerNorm(n_embed),
            dropout=nn.Dropout(p_drop),
        ))
        self.lm_head = nn.Linear(n_embed, vocab_size)
        self.apply(self._init_weight)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ) -> "tuple[torch.Tensor, Optional[torch.Tensor]]":
        """Run the model and optionally compute the next-token loss.

        Args:
            input_ids: Integer token ids of shape (batch, seq_len) with
                ``seq_len <= block_size``.
            labels: Optional token ids of shape (batch, seq_len). The shift
                by one position is done here, so ``labels`` can simply be
                ``input_ids``.

        Returns:
            ``(logits, loss)``. With ``labels``, ``logits`` covers positions
            ``0 .. seq_len - 2`` (shape ``(batch, seq_len - 1, vocab_size)``)
            and ``loss`` is the cross-entropy against ``labels[:, 1:]``.
            Without ``labels``, ``logits`` covers only the last position
            (shape ``(batch, 1, vocab_size)``) and ``loss`` is ``None``.

        Raises:
            AssertionError: If ``seq_len`` exceeds ``block_size``.
        """
        B, T = input_ids.size()
        assert T <= self.block_size, f"最大序列长度为{self.block_size}"
        positions = torch.arange(T, device=input_ids.device)
        # The (T, n_embed) position embeddings broadcast over the batch dimension.
        x = self.transformer.te(input_ids) + self.transformer.pe(positions)
        x = self.transformer.dropout(x)
        for block in self.transformer.blocks:
            x = block(x)
        x = self.transformer.ln(x)
        if labels is not None:
            # If we are given some desired labels also calculate the loss:
            # the prediction at position t is scored against the token at t + 1.
            logits = self.lm_head(x)
            logits = logits[:, :-1, :]
            labels = labels[:, 1:]
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        else:
            # Inference-time mini-optimization: only forward the lm_head on the very last position.
            logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss

    def _init_weight(self, module: nn.Module) -> None:
        """Initialise one submodule; intended to be used through ``self.apply``.

        Linear and Embedding weights are drawn from N(0, 0.02^2), Linear biases
        are zeroed, and LayerNorm is reset to weight 1 / bias 0.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    @torch.no_grad()
    def generate(self, ids: torch.Tensor, n_token: int, temprature: float = 1.0, top_k: Optional[int] = None) -> torch.Tensor:
        """Autoregressively sample new tokens and append them to ``ids``.

        The train/eval mode of the module is not changed here; call
        ``eval()`` beforehand if dropout should be disabled while sampling.

        Args:
            ids (torch.Tensor): Input token ids of shape (batch, seq_len).
            n_token (int): Number of tokens to generate.
            temprature (float, optional): Temperature the logits are divided
                by before sampling. Defaults to 1.0.
            top_k (Optional[int], optional): If given, sample only from the
                ``top_k`` highest-scoring tokens; values larger than the
                vocabulary size are clamped. Defaults to None.

        Returns:
            torch.Tensor: The full sequence (original ``ids`` followed by the
            sampled tokens), shape (batch, seq_len + n_token).
        """
        for _ in range(n_token):
            # Only the model input is cropped to the context window; the full
            # sequence is kept so the returned tensor still contains the prompt.
            context = ids if ids.size(1) <= self.block_size else ids[:, -self.block_size:]
            logits, _ = self(context)
            logits = logits[:, -1, :] / temprature  # (batch_size, vocab_size)
            if top_k is not None:
                # topk raises if k exceeds the vocabulary size, so clamp it.
                k = min(top_k, logits.size(-1))
                kth_best, _ = torch.topk(logits, k=k, dim=-1)  # (batch_size, k)
                # Everything scoring below the k-th best gets zero probability.
                logits = logits.masked_fill(logits < kth_best[:, [-1]], float("-inf"))
            scores = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(scores, num_samples=1)
            ids = torch.cat([ids, next_id], dim=1)
        return ids